-- SPDX-License-Identifier: MIT
-- Copyright (C) 2018-present iced project and contributors

describe("Register", function()
	local Register = require("iced_x86.Register")
	local RegisterExt = require("iced_x86.RegisterExt")
	local RegisterInfo = require("iced_x86.RegisterInfo")

	it("reg ext", function()
		assert.equals(Register.AL, RegisterExt.base(Register.DL))
		assert.equals(Register.AX, RegisterExt.base(Register.R8W))
		assert.equals(Register.EAX, RegisterExt.base(Register.R15D))
		assert.equals(Register.RAX, RegisterExt.base(Register.R13))
		assert.equals(Register.ES, RegisterExt.base(Register.FS))
		assert.equals(Register.XMM0, RegisterExt.base(Register.XMM2))
		assert.equals(Register.YMM0, RegisterExt.base(Register.YMM20))
		assert.equals(Register.ZMM0, RegisterExt.base(Register.ZMM31))

		assert.equals(2, RegisterExt.number(Register.DL))
		assert.equals(15, RegisterExt.number(Register.R15))
		assert.equals(21, RegisterExt.number(Register.YMM21))

		assert.equals(Register.RCX, RegisterExt.full_register(Register.CL))
		assert.equals(Register.RDX, RegisterExt.full_register(Register.DX))
		assert.equals(Register.RBX, RegisterExt.full_register(Register.EBX))
		assert.equals(Register.RSP, RegisterExt.full_register(Register.RSP))
		assert.equals(Register.ZMM2, RegisterExt.full_register(Register.XMM2))
		assert.equals(Register.ZMM22, RegisterExt.full_register(Register.YMM22))
		assert.equals(Register.ZMM11, RegisterExt.full_register(Register.ZMM11))

		assert.equals(Register.ECX, RegisterExt.full_register32(Register.CL))
		assert.equals(Register.EDX, RegisterExt.full_register32(Register.DX))
		assert.equals(Register.EBX, RegisterExt.full_register32(Register.EBX))
		assert.equals(Register.ESP, RegisterExt.full_register32(Register.RSP))
		assert.equals(Register.ZMM2, RegisterExt.full_register32(Register.XMM2))
		assert.equals(Register.ZMM22, RegisterExt.full_register32(Register.YMM22))
		assert.equals(Register.ZMM11, RegisterExt.full_register32(Register.ZMM11))

		assert.equals(1, RegisterExt.size(Register.DL))
		assert.equals(2, RegisterExt.size(Register.R8W))
		assert.equals(4, RegisterExt.size(Register.R15D))
		assert.equals(8, RegisterExt.size(Register.R13))
		assert.equals(2, RegisterExt.size(Register.FS))
		assert.equals(16, RegisterExt.size(Register.XMM2))
		assert.equals(32, RegisterExt.size(Register.YMM20))
		assert.equals(64, RegisterExt.size(Register.ZMM31))

		assert.is_false(RegisterExt.is_segment_register(Register.CX))
		assert.is_true(RegisterExt.is_segment_register(Register.GS))

		assert.is_true(RegisterExt.is_gpr(Register.CL))
		assert.is_true(RegisterExt.is_gpr(Register.DX))
		assert.is_true(RegisterExt.is_gpr(Register.ESP))
		assert.is_true(RegisterExt.is_gpr(Register.R15))
		assert.is_false(RegisterExt.is_gpr(Register.ES))

		assert.is_true(RegisterExt.is_gpr8(Register.CL))
		assert.is_false(RegisterExt.is_gpr8(Register.DX))
		assert.is_false(RegisterExt.is_gpr8(Register.ESP))
		assert.is_false(RegisterExt.is_gpr8(Register.R15))
		assert.is_false(RegisterExt.is_gpr8(Register.ES))

		assert.is_false(RegisterExt.is_gpr16(Register.CL))
		assert.is_true(RegisterExt.is_gpr16(Register.DX))
		assert.is_false(RegisterExt.is_gpr16(Register.ESP))
		assert.is_false(RegisterExt.is_gpr16(Register.R15))
		assert.is_false(RegisterExt.is_gpr16(Register.ES))

		assert.is_false(RegisterExt.is_gpr32(Register.CL))
		assert.is_false(RegisterExt.is_gpr32(Register.DX))
		assert.is_true(RegisterExt.is_gpr32(Register.ESP))
		assert.is_false(RegisterExt.is_gpr32(Register.R15))
		assert.is_false(RegisterExt.is_gpr32(Register.ES))

		assert.is_false(RegisterExt.is_gpr64(Register.CL))
		assert.is_false(RegisterExt.is_gpr64(Register.DX))
		assert.is_false(RegisterExt.is_gpr64(Register.ESP))
		assert.is_true(RegisterExt.is_gpr64(Register.R15))
		assert.is_false(RegisterExt.is_gpr64(Register.ES))

		assert.is_false(RegisterExt.is_vector_register(Register.CL))
		assert.is_true(RegisterExt.is_vector_register(Register.XMM1))
		assert.is_true(RegisterExt.is_vector_register(Register.YMM2))
		assert.is_true(RegisterExt.is_vector_register(Register.ZMM3))

		assert.is_false(RegisterExt.is_xmm(Register.CL))
		assert.is_true(RegisterExt.is_xmm(Register.XMM1))
		assert.is_false(RegisterExt.is_xmm(Register.YMM2))
		assert.is_false(RegisterExt.is_xmm(Register.ZMM3))

		assert.is_false(RegisterExt.is_ymm(Register.CL))
		assert.is_false(RegisterExt.is_ymm(Register.XMM1))
		assert.is_true(RegisterExt.is_ymm(Register.YMM2))
		assert.is_false(RegisterExt.is_ymm(Register.ZMM3))

		assert.is_false(RegisterExt.is_zmm(Register.CL))
		assert.is_false(RegisterExt.is_zmm(Register.XMM1))
		assert.is_false(RegisterExt.is_zmm(Register.YMM2))
		assert.is_true(RegisterExt.is_zmm(Register.ZMM3))

		assert.is_false(RegisterExt.is_ip(Register.CL))
		assert.is_true(RegisterExt.is_ip(Register.EIP))
		assert.is_true(RegisterExt.is_ip(Register.RIP))

		assert.is_false(RegisterExt.is_k(Register.CL))
		assert.is_true(RegisterExt.is_k(Register.K3))

		assert.is_false(RegisterExt.is_cr(Register.CL))
		assert.is_true(RegisterExt.is_cr(Register.CR3))

		assert.is_false(RegisterExt.is_dr(Register.CL))
		assert.is_true(RegisterExt.is_dr(Register.DR3))

		assert.is_false(RegisterExt.is_tr(Register.CL))
		assert.is_true(RegisterExt.is_tr(Register.TR3))

		assert.is_false(RegisterExt.is_st(Register.CL))
		assert.is_true(RegisterExt.is_st(Register.ST3))

		assert.is_false(RegisterExt.is_bnd(Register.CL))
		assert.is_true(RegisterExt.is_bnd(Register.BND3))

		assert.is_false(RegisterExt.is_mm(Register.CL))
		assert.is_true(RegisterExt.is_mm(Register.MM3))

		assert.is_false(RegisterExt.is_tmm(Register.CL))
		assert.is_true(RegisterExt.is_tmm(Register.TMM3))
	end)

	it("create", function()
		local fns = {
			function(register)
				return RegisterExt.info(register)
			end,
			function(register)
				return RegisterInfo.new(register)
			end,
		}
		for _, create in ipairs(fns) do
			---@type RegisterInfo
			local info = create(Register.R10D)
			assert.equals(Register.R10D, info:register())
			assert.equals(Register.EAX, info:base())
			assert.equals(10, info:number())
			assert.equals(Register.R10, info:full_register())
			assert.equals(Register.R10D, info:full_register32())
			assert.equals(4, info:size())
		end
	end)

	-- stylua: ignore
	it("invalid", function()
		assert.has_error(function() RegisterExt.info(0x789A) end)
		assert.has_error(function() RegisterExt.info(-0x80000001) end)
		assert.has_error(function() RegisterExt.info(0x100000000) end)
		assert.has_error(function() RegisterExt.base(0x789A) end)
		assert.has_error(function() RegisterExt.base(-0x80000001) end)
		assert.has_error(function() RegisterExt.base(0x100000000) end)
		assert.has_error(function() RegisterExt.number(0x789A) end)
		assert.has_error(function() RegisterExt.number(-0x80000001) end)
		assert.has_error(function() RegisterExt.number(0x100000000) end)
		assert.has_error(function() RegisterExt.full_register(0x789A) end)
		assert.has_error(function() RegisterExt.full_register(-0x80000001) end)
		assert.has_error(function() RegisterExt.full_register(0x100000000) end)
		assert.has_error(function() RegisterExt.full_register32(0x789A) end)
		assert.has_error(function() RegisterExt.full_register32(-0x80000001) end)
		assert.has_error(function() RegisterExt.full_register32(0x100000000) end)
		assert.has_error(function() RegisterExt.size(0x789A) end)
		assert.has_error(function() RegisterExt.size(-0x80000001) end)
		assert.has_error(function() RegisterExt.size(0x100000000) end)
		assert.has_error(function() RegisterExt.is_segment_register(0x789A) end)
		assert.has_error(function() RegisterExt.is_segment_register(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_segment_register(0x100000000) end)
		assert.has_error(function() RegisterExt.is_gpr(0x789A) end)
		assert.has_error(function() RegisterExt.is_gpr(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_gpr(0x100000000) end)
		assert.has_error(function() RegisterExt.is_gpr8(0x789A) end)
		assert.has_error(function() RegisterExt.is_gpr8(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_gpr8(0x100000000) end)
		assert.has_error(function() RegisterExt.is_gpr16(0x789A) end)
		assert.has_error(function() RegisterExt.is_gpr16(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_gpr16(0x100000000) end)
		assert.has_error(function() RegisterExt.is_gpr32(0x789A) end)
		assert.has_error(function() RegisterExt.is_gpr32(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_gpr32(0x100000000) end)
		assert.has_error(function() RegisterExt.is_gpr64(0x789A) end)
		assert.has_error(function() RegisterExt.is_gpr64(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_gpr64(0x100000000) end)
		assert.has_error(function() RegisterExt.is_xmm(0x789A) end)
		assert.has_error(function() RegisterExt.is_xmm(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_xmm(0x100000000) end)
		assert.has_error(function() RegisterExt.is_ymm(0x789A) end)
		assert.has_error(function() RegisterExt.is_ymm(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_ymm(0x100000000) end)
		assert.has_error(function() RegisterExt.is_zmm(0x789A) end)
		assert.has_error(function() RegisterExt.is_zmm(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_zmm(0x100000000) end)
		assert.has_error(function() RegisterExt.is_vector_register(0x789A) end)
		assert.has_error(function() RegisterExt.is_vector_register(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_vector_register(0x100000000) end)
		assert.has_error(function() RegisterExt.is_ip(0x789A) end)
		assert.has_error(function() RegisterExt.is_ip(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_ip(0x100000000) end)
		assert.has_error(function() RegisterExt.is_k(0x789A) end)
		assert.has_error(function() RegisterExt.is_k(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_k(0x100000000) end)
		assert.has_error(function() RegisterExt.is_cr(0x789A) end)
		assert.has_error(function() RegisterExt.is_cr(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_cr(0x100000000) end)
		assert.has_error(function() RegisterExt.is_dr(0x789A) end)
		assert.has_error(function() RegisterExt.is_dr(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_dr(0x100000000) end)
		assert.has_error(function() RegisterExt.is_tr(0x789A) end)
		assert.has_error(function() RegisterExt.is_tr(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_tr(0x100000000) end)
		assert.has_error(function() RegisterExt.is_st(0x789A) end)
		assert.has_error(function() RegisterExt.is_st(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_st(0x100000000) end)
		assert.has_error(function() RegisterExt.is_bnd(0x789A) end)
		assert.has_error(function() RegisterExt.is_bnd(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_bnd(0x100000000) end)
		assert.has_error(function() RegisterExt.is_mm(0x789A) end)
		assert.has_error(function() RegisterExt.is_mm(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_mm(0x100000000) end)
		assert.has_error(function() RegisterExt.is_tmm(0x789A) end)
		assert.has_error(function() RegisterExt.is_tmm(-0x80000001) end)
		assert.has_error(function() RegisterExt.is_tmm(0x100000000) end)
		assert.has_error(function() RegisterInfo.new(0x789A) end)
		assert.has_error(function() RegisterInfo.new(-0x80000001) end)
		assert.has_error(function() RegisterInfo.new(0x100000000) end)
	end)
end)
